Libraries

Set random seed for reproducibility

# Fix the RNG state once, before any resampling, model fitting or LIME
# permutations below, so the whole document is reproducible.
set.seed(seed = 1234L)
library(dplyr)
library(lubridate)
library(lime)       # ML local interpretation
library(vip)        # ML global interpretation
library(pdp)        # ML global interpretation
library(ggplot2)    # visualization pkg leveraged by above packages
library(caret)      # ML model building

Data

Read in data

# Load the raw analysis table (path is relative to the working directory).
all.df <- utils::read.csv(file = "./data/all.df.csv")

Convert dates

# Parse the four date columns with lubridate::ymd(), one column at a time in
# the order dot, dor, bdate, pdate (they are presumably read as text by
# read.csv — confirm against the CSV).
all.df[c("dot", "dor", "bdate", "pdate")] <-
  lapply(all.df[c("dot", "dor", "bdate", "pdate")], ymd)

Convert all character strings to factors

# Turn every character column into a factor (the date columns are already
# Date objects at this point, so they are untouched).
# across(where(...)) replaces the superseded mutate_if().
all.df <- all.df %>%
  mutate(across(where(is.character), as.factor))

Make the outcome a two-level factor (relapse "yes"/"no", with "yes" as the first level)

# Re-level the outcome so "yes" is the first level and "no" the second;
# any value other than "yes"/"no" becomes NA.
all.df[["rbin"]] <- factor(all.df[["rbin"]], levels = c("yes", "no"))

Filter out any tests taken on or after the relapse date

# Keep tests dated strictly before the relapse date, plus all tests of rows
# without a relapse date. which() discards rows whose condition is NA
# (e.g. a missing bdate when dor is present).
all.df <- all.df[with(all.df, which(is.na(dor) | bdate < dor)), ]

Filter out relapses at 720 days or later (only relapses with rtime < 720 are kept)

# Keep non-relapse rows and relapses with rtime strictly below 720; note that
# a relapse at exactly 720 is dropped too. Rows where the condition is NA are
# discarded by which().
all.df <- all.df[with(all.df, which(rtime < 720 | rbin == "no")), ]

Filter out rows with any missing bmc_*/pbc_* test value

# Keep only rows in which all eight bmc_*/pbc_* test values are present.
all.df <- all.df[complete.cases(all.df[, c("bmc_cdw", "bmc_cd3",
                                           "bmc_cd15", "bmc_cd34",
                                           "pbc_cdw", "pbc_cd3",
                                           "pbc_cd15", "pbc_cd34")]), ]
# NOTE(review): at top level in the global environment `<<-` behaves like
# `<-` (this line is then a no-op); if the document is evaluated in another
# environment it copies all.df into an enclosing one. Confirm it is needed.
all.df <<- all.df

Set up for LIME plots

Select the required columns

rbin ~ txage + sex + rstatprtx + hla + tbi + gmgp + agvhd + cgvhd + bmc_cdw + bmc_cd3 + bmc_cd15 + bmc_cd34 + pbc_cdw + pbc_cd3 + pbc_cd15 + pbc_cd34

# Modelling table: the outcome rbin first, then the 16 predictors.
# NOTE(review): `<<-` rather than `<-` — identical at top level in the global
# environment, otherwise it writes to an enclosing environment.
all.df2 <<- select(
  all.df,
  rbin,
  # covariates ("e" removed: it has only one level)
  txage, sex, rstatprtx, hla, tbi, gmgp, agvhd, cgvhd,
  # bmc_* test values
  bmc_cdw, bmc_cd3, bmc_cd15, bmc_cd34,
  # pbc_* test values
  pbc_cdw, pbc_cd3, pbc_cd15, pbc_cd34
)

Set up random forest through caret

# fit.caret <- train(
#   rbin ~ ., 
#   data = all.df2, 
#   method = 'rf',
#   trControl = trainControl(method = "cv", number = 5, classProbs = TRUE),
#   tuneLength = 1
# )

# Random forest via caret with train()'s default resampling (the printed
# summary reports 25 bootstrap reps). The summary shows mtry values up to 19
# for 16 predictors — presumably because the formula interface expands the
# factor predictors into dummy columns.
fit.caret <- train(
  rbin ~ .,
  data   = all.df2,
  method = "rf"
)
print(fit.caret)
## Random Forest 
## 
## 140 samples
##  16 predictor
##   2 classes: 'yes', 'no' 
## 
## No pre-processing
## Resampling: Bootstrapped (25 reps) 
## Summary of sample sizes: 140, 140, 140, 140, 140, 140, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.8799954  0.5983077
##   10    0.9034804  0.7030361
##   19    0.9097140  0.7297991
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 19.

Direct randomForest fit (optional for the caret/LIME workflow, but used by the variable importance plots below)

# Plain randomForest fit on the same formula and data; the vip permutation
# importance calls below use this model.
fit.rf <- randomForest::randomForest(rbin ~ ., data = all.df2)

Variable importance plots

# Prediction wrapper for vip's permutation importance. With metric = "auc"
# and reference_class = "no", the wrapper must return the predicted
# probability of the reference class; plain `predict` on a randomForest
# classifier returns hard class labels, which makes the AUC degenerate.
pfun <- function(object, newdata) {
  predict(object, newdata = newdata, type = "prob")[, "no"]
}

# Quick look: 10 permutations per feature. Title names the actual model
# (randomForest, not ranger).
vip(fit.rf, method = "permute", data = all.df2, target = "rbin",
    metric = "auc", pred_wrapper = pfun,
    reference_class = "no", nsim = 10) + ggtitle("randomForest: RF")

# vis <- vi(fit.ranger, method = "permute", data = all.df2, target = "rbin", 
#           metric = "auc", pred_wrapper = pfun, 
#           reference_class = "no", nsim = 100) 
# vip(vis, geom = "boxplot") # Figure 12

# Full run: 100 permutations per feature, kept in `vis` for the custom
# bar charts below.
vis <- vi(fit.rf, method = "permute", data = all.df2, target = "rbin",
          metric = "auc", pred_wrapper = pfun,
          reference_class = "no", nsim = 100)
vip(vis, geom = "boxplot") # Figure 12

# Draft importance chart (unsorted): one bar per variable with
# Importance +/- StDev error bars, drawn horizontally.
p <- ggplot(vis, aes(x = Variable, y = Importance))
p +
  geom_bar(stat = "identity", colour = "black", position = position_dodge()) +
  geom_errorbar(
    mapping = aes(ymin = Importance - StDev, ymax = Importance + StDev),
    width = 0.2
  ) +
  coord_flip()

# Final importance chart: variables ordered by Importance, +/- StDev error
# bars, horizontal bars, black-and-white theme, axis titled "Variable".
p <- ggplot(vis, aes(x = reorder(Variable, Importance), y = Importance))
p <- p +
  geom_bar(stat = "identity", colour = "black", position = position_dodge()) +
  geom_errorbar(
    mapping = aes(ymin = Importance - StDev, ymax = Importance + StDev),
    width = 0.2
  )
p <- p + coord_flip() + theme_bw() + scale_x_discrete(name = "Variable")
print(p)

#ggsave("./lime_plots/vip.pdf", plot = p)

Explainers

# LIME explainer for the caret model, binning continuous features into 5 bins.
# NOTE(review): all.df2 still contains the outcome column rbin, so it is
# passed to lime as a feature (the summary reports 17 feature entries for the
# 16 predictors). Dropping it here would presumably also require dropping it
# from every `x` given to explain() below — TODO fix both together.
explainer_caret <- lime(x = all.df2, model = fit.caret, n_bins = 5)
# explainer_rf <- lime(all.df2, fit.rf, n_bins = 5)

summary(explainer_caret)
##                      Length Class  Mode     
## model                24     train  list     
## preprocess            1     -none- function 
## bin_continuous        1     -none- logical  
## n_bins                1     -none- numeric  
## quantile_bins         1     -none- logical  
## use_density           1     -none- logical  
## feature_type         17     -none- character
## bin_cuts             17     -none- list     
## feature_distribution 17     -none- list
# summary(explainer_rf)

Example explainer plot for patient ID 1

# Rows (tests) belonging to patient ID 1; all of them are explained together.
patientID <- which(all.df[["ID"]] == 1)

# Explain the "yes" probability with the 10 highest-weight features, using
# 5000 permutations, Gower distance and kernel width 0.75.
explanation_caret <- explain(
  x              = all.df2[patientID, ],
  explainer      = explainer_caret,
  labels         = "yes",
  n_features     = 10,
  feature_select = "highest_weights",
  n_permutations = 5000,
  dist_fun       = "gower",
  kernel_width   = 0.75
)

# Feature plot is stored (not printed); the heatmap overview is printed.
p1 <- plot_features(explanation_caret)

plot_explanations(explanation_caret)

All patients

# One LIME explanation and feature plot per patient ID remaining in all.df.
# All of a patient's rows go into a single explain() call, and the plot is
# titled with the patient ID.
all_patients <- unique(all.df$ID)
for (i in seq_along(all_patients)) {
  patient <- all_patients[i]
  patientID <- which(all.df$ID == patient)

  explanation_caret <- explain(
    x = all.df2[patientID, ],
    explainer = explainer_caret,
    n_permutations = 5000,
    dist_fun = "gower",
    kernel_width = 0.75,
    n_features = 10,
    feature_select = "highest_weights",
    labels = "yes"
  )

  p1 <- plot_features(explanation_caret) + ggtitle(paste("Patient", patient))
  print(p1)
  # File named by patient ID so it matches the plot title.
  #ggsave(paste0("./lime_plots/patient_", patient, ".pdf"), plot = p1)
}

Selected patients

Patient 8 (ID = 5)

# Selected patient, identified by ID (plot titles below show this ID).
j <- 5
patientID <- which(all.df$ID == j)

## Print one line of table to check
all.df[patientID[1], ]
##   ID        dot        dor txage relapse sex rstatprtx hla tbi       gmgp agvhd
## 5  5 2013-08-27 2014-09-08  18.4     yes   M       CR2   0 yes BM/PB,0,no    no
##   cgvhd test      bdate btime rtime rbin bmc_cdw bmc_cd3 bmc_cd15 bmc_cd34
## 5   yes D1MC 2013-09-24    28   377  yes      99      83      100       98
##        pdate ptime pbc_cdw pbc_cd3 pbc_cd15 pbc_cd34
## 5 2013-09-24    28      96      77       98       93

# One explanation (top 10 features for label "yes") and one plot per row of
# this patient; every plot gets the same "Patient <ID>" title.
for (i in patientID) {
  explanation_caret <- explain(
    x = all.df2[i, ],
    explainer = explainer_caret,
    n_permutations = 5000,
    dist_fun = "gower",
    kernel_width = 0.75,
    n_features = 10,
    feature_select = "highest_weights",
    labels = "yes"
  )
  p1 <- plot_features(explanation_caret) + ggtitle(paste("Patient", j))
  print(p1)
}

Patient 15 (ID = 43)

# Selected patient, identified by ID (plot titles below show this ID).
j <- 43
patientID <- which(all.df$ID == j)

## Print one line of table to check
all.df[patientID[1], ]
##    ID        dot        dor txage relapse sex rstatprtx hla tbi       gmgp
## 89 43 2018-08-10 2020-04-21   6.5     yes   M       CR2   5 yes PB,≤5, yes
##    agvhd cgvhd test      bdate btime rtime rbin bmc_cdw bmc_cd3 bmc_cd15
## 89    no    no D2MC 2018-10-10    61   620  yes     100     100       99
##    bmc_cd34      pdate ptime pbc_cdw pbc_cd3 pbc_cd15 pbc_cd34
## 89      100 2018-10-10    61     100     100      100      100

# One explanation (top 10 features for label "yes") and one plot per row of
# this patient; every plot gets the same "Patient <ID>" title.
for (i in patientID) {
  explanation_caret <- explain(
    x = all.df2[i, ],
    explainer = explainer_caret,
    n_permutations = 5000,
    dist_fun = "gower",
    kernel_width = 0.75,
    n_features = 10,
    feature_select = "highest_weights",
    labels = "yes"
  )
  p1 <- plot_features(explanation_caret) + ggtitle(paste("Patient", j))
  print(p1)
}

Patient 16 (ID = 8)

# Selected patient, identified by ID (plot titles below show this ID).
j <- 8
patientID <- which(all.df$ID == j)

## Print one line of table to check
all.df[patientID[1], ]
##   ID        dot        dor txage relapse sex rstatprtx hla tbi       gmgp agvhd
## 8  8 2018-09-18 2019-12-10  15.8     yes   M       CR1   0 yes BM/PB,0,no   yes
##   cgvhd test      bdate btime rtime rbin bmc_cdw bmc_cd3 bmc_cd15 bmc_cd34
## 8    no D1MC 2018-10-16    28   448  yes     100      99      100      100
##        pdate ptime pbc_cdw pbc_cd3 pbc_cd15 pbc_cd34
## 8 2018-10-16    28     100     100       99      100

# One explanation (top 10 features for label "yes") and one plot per row of
# this patient; every plot gets the same "Patient <ID>" title.
for (i in patientID) {
  explanation_caret <- explain(
    x = all.df2[i, ],
    explainer = explainer_caret,
    n_permutations = 5000,
    dist_fun = "gower",
    kernel_width = 0.75,
    n_features = 10,
    feature_select = "highest_weights",
    labels = "yes"
  )
  p1 <- plot_features(explanation_caret) + ggtitle(paste("Patient", j))
  print(p1)
}

Patient 17 (ID = 45)

# Selected patient, identified by ID (plot titles below show this ID).
j <- 45
patientID <- which(all.df$ID == j)

## Print one line of table to check
all.df[patientID[1], ]
##    ID        dot        dor txage relapse sex rstatprtx hla tbi       gmgp
## 45 45 2018-11-07 2019-04-29   9.9     yes   M       CR3   4 yes PB,≤5, yes
##    agvhd cgvhd test      bdate btime rtime rbin bmc_cdw bmc_cd3 bmc_cd15
## 45    no    no D1MC 2018-12-06    29   173  yes     100      99      100
##    bmc_cd34      pdate ptime pbc_cdw pbc_cd3 pbc_cd15 pbc_cd34
## 45      100 2018-11-23    16     100     100      100      100

# One explanation (top 10 features for label "yes") and one plot per row of
# this patient; every plot gets the same "Patient <ID>" title.
for (i in patientID) {
  explanation_caret <- explain(
    x = all.df2[i, ],
    explainer = explainer_caret,
    n_permutations = 5000,
    dist_fun = "gower",
    kernel_width = 0.75,
    n_features = 10,
    feature_select = "highest_weights",
    labels = "yes"
  )
  p1 <- plot_features(explanation_caret) + ggtitle(paste("Patient", j))
  print(p1)
}


  1. Stanford Medicine, ↩︎

  2. University of Utah, ↩︎